#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>

// Checks the CUDA runtime's last-error state and terminates the process on failure.
//
// The macro queries cudaGetLastError() (which reads and clears the recorded error);
// it does not inspect the return value of any particular call, so it should be
// invoked right after the CUDA call it is meant to guard.
//
// On error it prints `msg`, the CUDA error string and the source file/line to
// stderr, then calls exit(1).
#define cudaCheckErrors(msg)                                                                                  \
  do {                                                                                                        \
    cudaError_t __err = cudaGetLastError();                                                                   \
    if (__err != cudaSuccess) {                                                                               \
      fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg, cudaGetErrorString(__err), __FILE__, __LINE__); \
      fprintf(stderr, "*** FAILED - ABORTING\n");                                                             \
      exit(1);                                                                                                \
    }                                                                                                         \
  } while (0)

// Hash functor that lets cudaIpcMemHandle_t be used as an unordered-container key.
//
// The handle is treated as an opaque blob of sizeof(cudaIpcMemHandle_t) bytes and
// hashed with 64-bit FNV-1a. A plain sum of the bytes would ignore byte order and
// can only produce sizeof(handle) * 255 + 1 distinct values, which invites
// collisions between handles.
template <>
struct std::hash<cudaIpcMemHandle_t> {
  // Returns a hash of the raw bytes of `handle`.
  size_t operator()(const cudaIpcMemHandle_t& handle) const {
    // unsigned char may alias any object, and const-ness is preserved.
    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(&handle);
    uint64_t hash = 0xcbf29ce484222325ULL;  // FNV-1a 64-bit offset basis
    for (size_t i = 0; i < sizeof(cudaIpcMemHandle_t); ++i) {
      hash ^= bytes[i];
      hash *= 0x100000001b3ULL;  // FNV-1a 64-bit prime
    }
    // Truncation on 32-bit size_t is acceptable for a hash value.
    return static_cast<size_t>(hash);
  }
};

// Equality functor for cudaIpcMemHandle_t: two handles compare equal exactly when
// their sizeof(cudaIpcMemHandle_t) bytes are identical.
template <>
struct std::equal_to<cudaIpcMemHandle_t> {
  bool operator()(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) const {
    const void* left_bytes = &lhs;
    const void* right_bytes = &rhs;
    const int difference = std::memcmp(left_bytes, right_bytes, sizeof(cudaIpcMemHandle_t));
    return difference == 0;
  }
};

namespace {

namespace gpuipc {
// Sizing constants, taken from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h

// Threads cooperating on one pixel.
constexpr int THREADS_PER_PIXEL = 16;

// Elements fetched by a single ldg.
constexpr int ELEMENTS_PER_LDG = 4;

// Count of reduction ops; each one (mean, var, dscale, dbias) gets its own space.
constexpr int REDUCE_OPS = 4;

// Largest supported block.y; bounded by how the buffer is allocated.
constexpr int MAX_BLOCK_Y = 256;

// Total number of reduction slots across all ops and block.y values.
constexpr int MAX_OFFSET = REDUCE_OPS * MAX_BLOCK_Y;

// Size in bytes of one stored element.
constexpr int BYTES_PER_ELEM = 4;

// Bytes of buffer consumed by one sync step.
constexpr int SINGLE_SYNC_BUFFER_BYTES =
    MAX_OFFSET * THREADS_PER_PIXEL * 2 * ELEMENTS_PER_LDG * BYTES_PER_ELEM;
}  // namespace gpuipc

class IpcMemHandleRegistry {
 public:
  void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) {
    if (registry_.count(handle) == 0) {
      registry_.insert(std::make_pair(handle, RegistryEntry()));
      registry_[handle].dev_ptr = ipcOpenMem(handle);
    }
    registry_[handle].ref_count++;
    return (((uint8_t*)registry_[handle].dev_ptr) + offset);
  }

  void releasePtr(const cudaIpcMemHandle_t& handle) {
    if (registry_.count(handle) == 0) {
    }
    if (--registry_[handle].ref_count == 0) {
      ipcCloseMem(registry_[handle].dev_ptr);
      registry_.erase(handle);
    }
  }

  struct RegistryEntry {
    void* dev_ptr;
    int ref_count;
    RegistryEntry() : dev_ptr(NULL), ref_count(0) {}
  };

 protected:
  std::unordered_map<cudaIpcMemHandle_t, RegistryEntry> registry_;

  void* ipcOpenMem(const cudaIpcMemHandle_t& handle) {
    void* data;
    cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess);
    cudaCheckErrors("ipc init");
    return data;
  }

  void ipcCloseMem(void* dev_ptr) {
    cudaIpcCloseMemHandle(dev_ptr);
    cudaCheckErrors("ipc close");
  }
};

}  // namespace

// File-local registry instance used by get_remote_data_ptr() and
// close_remote_data() to open, share and close IPC memory handles.
static IpcMemHandleRegistry ipc_mem_registry;

// Returns the buffer size in bytes for `bn_sync_steps` sync steps, i.e.
// bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES.
int64_t get_buffer_size(const int bn_sync_steps) {
  // Widen before multiplying: both operands are int, so the product would
  // otherwise be computed in 32 bits and overflow once bn_sync_steps reaches
  // 4096 (4096 * 524288 == 2^31).
  return static_cast<int64_t>(bn_sync_steps) * gpuipc::SINGLE_SYNC_BUFFER_BYTES;
}

// Returns the device address of the IPC memory described by `handle`, advanced by
// `offset` bytes.
//
// `handle` must be a contiguous, host-resident uint8 tensor holding at least
// sizeof(cudaIpcMemHandle_t) elements; the serialized handle is read from it with
// a host memcpy. The handle is opened through the file's registry on first use and
// one reference is taken per call; release it with close_remote_data(handle).
//
// Raises a c10::Error (via TORCH_CHECK) if the tensor does not meet those
// requirements, instead of reading out of bounds.
void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) {
  cudaIpcMemHandle_t my_handle;
  TORCH_CHECK(!handle.is_cuda() && handle.is_contiguous(),
              "IPC handle tensor must be a contiguous host tensor");
  TORCH_CHECK(handle.numel() >= static_cast<int64_t>(sizeof(my_handle)),
              "IPC handle tensor has ", handle.numel(), " elements, expected at least ",
              sizeof(my_handle));
  // data_ptr<uint8_t>() itself rejects tensors whose dtype is not uint8.
  memcpy(&my_handle, handle.data_ptr<uint8_t>(), sizeof(my_handle));
  return ipc_mem_registry.getPtr(my_handle, offset);
}

// Releases one reference on the IPC memory described by `handle`; the registry
// closes the mapping when its last reference is dropped.
//
// `handle` must be a contiguous, host-resident uint8 tensor holding at least
// sizeof(cudaIpcMemHandle_t) elements; the serialized handle is read from it with
// a host memcpy.
//
// Raises a c10::Error (via TORCH_CHECK) if the tensor does not meet those
// requirements, instead of reading out of bounds.
void close_remote_data(const at::Tensor& handle) {
  cudaIpcMemHandle_t my_handle;
  TORCH_CHECK(!handle.is_cuda() && handle.is_contiguous(),
              "IPC handle tensor must be a contiguous host tensor");
  TORCH_CHECK(handle.numel() >= static_cast<int64_t>(sizeof(my_handle)),
              "IPC handle tensor has ", handle.numel(), " elements, expected at least ",
              sizeof(my_handle));
  // data_ptr<uint8_t>() itself rejects tensors whose dtype is not uint8.
  memcpy(&my_handle, handle.data_ptr<uint8_t>(), sizeof(my_handle));
  ipc_mem_registry.releasePtr(my_handle);
}

// Returns the result of `data.data_ptr<uint8_t>()` as an untyped pointer.
void* get_data_ptr(const at::Tensor& data) {
  uint8_t* const first_byte = data.data_ptr<uint8_t>();
  return first_byte;
}
